#!/usr/bin/python3
# ******************************************************************************
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
# licensed under the Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#     http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN 'AS IS' BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
# PURPOSE.
# See the Mulan PSL v2 for more details.
# ******************************************************************************/
import re
from collections import defaultdict
from typing import Tuple, Dict

from ceres.conf.constant import CommandExitCode
from ceres.function.log import LOGGER
from ceres.function.util import execute_shell_command
from ceres.function.check import PreCheck
from ceres.function.status import PRE_CHECK_ERROR, SUCCESS
from ceres.manages.collect_manage import Collect

# Public API of this module: only the CVE scan manager class is exported.
__all__ = ['CveScanManage']


class CveScanManage:
    """
    Scan the host for fixed and unfixed CVEs using dnf and its hotpatch plugin.

    Attributes (all refreshed by cve_scan):
        kernel_filter: value of the "kernel" scan argument; when truthy, every query is piped
            through `grep kernel` and parsed with kernel-specific regular expressions.
        installed_rpm_info: mapping of rpm name -> "name-version-release.arch" of installed packages.
        available_hotpatch_key_set: "<cve_id>-<rpm name>" keys that still have an available hotpatch;
            filled by _query_unfixed_cves_by_dnf_plugin and read by _query_applied_hotpatch_status.
    """

    def __init__(self) -> None:
        self.kernel_filter = None
        self.installed_rpm_info = None
        self.available_hotpatch_key_set = set()

    def cve_scan(self, cve_scan_args: dict) -> Tuple[str, dict]:
        """
        query vulnerability info in the machine

        Args:
            cve_scan_args(dict): e.g
                {
                    "check_items": ["network"],
                    "kernel": True
                }

            check_items: Items that need to be checked before execution.
            kernel: optional; when truthy only kernel-related packages are scanned.

        Returns:
            status code: PRE_CHECK_ERROR when the pre-check fails, otherwise SUCCESS
            dict: e.g
                {
                    "check_items": [{
                        "item": "network",
                        "result":False,
                        "log":"check log"
                        }],
                    "unfixed_cves": [{
                        "cve_id": "CVE-2023-1513",
                        "installed_rpm": "kernel-4.19.90-2304.1.0.0131.oe1.x86_64",
                        "available_rpm": "kernel-4.19.90-2304.1.0.0196.oe1.x86_64",
                        "support_way": "coldpatch",
                        }],
                    "fixed_cves": [
                        {
                            "cve_id": "CVE-2023-1112",
                            "installed_rpm":"redis-4.2.5-1.oe2203.x86_64",
                            "fix_way": "hotpatch",
                            "hp_status": "ACCEPTED" //only hotpatch has the field
                        },
                        {
                            "cve_id": "CVE-2023-1112",
                            "installed_rpm":"redis-4.2.5-1.oe2203.x86_64",
                            "fix_way": "coldpatch"
                        }
                    ]
                }
        """
        cve_scan_result = {}
        check_result, items_check_log = PreCheck.execute_check(cve_scan_args.get("check_items"))
        cve_scan_result["check_items"] = items_check_log
        if not check_result:
            LOGGER.info("The pre-check is failed before execute command!")
            return PRE_CHECK_ERROR, cve_scan_result

        # The filter must be updated BEFORE querying installed packages; otherwise the installed
        # package list is built with the previous scan's filter.
        self.kernel_filter = cve_scan_args.get("kernel")
        self.installed_rpm_info = self._query_installed_rpm()
        self.available_hotpatch_key_set = set()

        # Order matters: the unfixed query fills available_hotpatch_key_set, which the fixed query
        # consults through _query_applied_hotpatch_status.
        unfixed_cves = self._query_unfixed_cves_by_dnf_plugin()
        fixed_cves = self._query_fixed_cves_by_dnf_plugin()
        cve_scan_result.update({"unfixed_cves": unfixed_cves, "fixed_cves": fixed_cves})
        return SUCCESS, cve_scan_result

    def _get_installed_rpm(self, rpm_name: str):
        """
        Look up the installed "name-version-release.arch" string of an rpm.

        Returns:
            str or None: None when the rpm is unknown or installed rpm info has not been queried yet.
        """
        return (self.installed_rpm_info or {}).get(rpm_name)

    @staticmethod
    def _parse_installed_rpm_output(stdout: str) -> Dict[str, str]:
        """
        Parse "name:name-version-release.arch" lines into a dict; lines without ':' are skipped.

        Returns:
            dict: e.g {"kernel": "kernel-4.19.90-2310.3.0.0222.oe1.x86_64"}
        """
        rpm_info_dict = {}
        for line in stdout.splitlines():
            rpm_name, separator, rpm_info = line.partition(":")
            if not separator:
                continue
            rpm_info_dict[rpm_name] = rpm_info
        return rpm_info_dict

    @staticmethod
    def _hotpatch_base_package_name(patch_name: str) -> str:
        """
        Extract the rpm name from a "base-pkg/hotpatch[/binary]" entry of `dnf hotpatch --list cves`.

        The base package is "name-version-release"; since version and release contain no '-',
        stripping the last two '-' separated fields keeps hyphenated names intact.
        e.g "redis-6.2.5-1/ACC-1-1/redis-cli" -> "redis",
            "avahi-libs-0.8-9.oe1/ACC-1-1/avahi-daemon" -> "avahi-libs"
        """
        base_package = patch_name.split("/", 1)[0]
        return base_package.rsplit("-", 2)[0]

    def _query_installed_rpm(self) -> Dict[str, str]:
        """
        query installed rpm package info

        Returns:
            dict: all rpm info, empty when the query fails. e.g
                {
                    "kernel":"kernel-5.10.0-60.92.0.116.oe2203.aarch64"
                }
        """
        rpm_info_dict = {}
        # Example of command execution result:
        # openldap:openldap-2.4.50-6.oe1.x86_64
        # kernel:kernel-4.19.90-2310.3.0.0222.oe1.x86_64
        # systemtap-runtime:systemtap-runtime-4.3-2.oe1.x86_64
        # perl-Net-SSLeay:perl-Net-SSLeay-1.88-5.oe1.x86_64
        # powertop:powertop-2.9-12.oe1.x86_64
        # libusbx:libusbx-1.0.23-1.oe1.x86_64
        commands = ["rpm -qa --queryformat '%{NAME}:%{NAME}-%{VERSION}-%{RELEASE}.%{ARCH}\n'"]
        if self.kernel_filter:
            commands.append("grep kernel")

        code, stdout, _ = execute_shell_command(commands)
        if code != CommandExitCode.SUCCEED or not stdout:
            LOGGER.error("query installed packages info failed!")
            return rpm_info_dict

        rpm_info_dict = self._parse_installed_rpm_output(stdout)

        # The "kernel" entry always describes the running kernel, not whichever kernel rpm was listed last.
        current_kernel_version = Collect.get_current_kernel_version()
        rpm_info_dict["kernel"] = f"kernel-{current_kernel_version}" if current_kernel_version else ""

        LOGGER.debug("query installed rpm package info succeed!")
        return rpm_info_dict

    def _build_unfixed_cve_info(self, cve_id: str, rpm_name: str, support_way, available_rpm: str) -> dict:
        """
        Build one unfixed-CVE record.

        Args:
            support_way: "hotpatch", "coldpatch", or None when no patch is available.
            available_rpm: the hotpatch or coldpatch package that fixes the CVE ("-" when none).
        """
        return {
            "cve_id": cve_id,
            "installed_rpm": self._get_installed_rpm(rpm_name),
            "available_rpm": available_rpm,
            "support_way": support_way,
        }

    def _query_unfixed_cves_by_dnf(self) -> list:
        """
        parse unfixed vulnerability info by dnf (coldpatch only)

        Return:
            list: cve info, empty when the command fails e.g
                [{
                    "cve_id": "CVE-2023-1513",
                    "installed_rpm": "kernel-4.19.90-2304.1.0.0131.oe1.x86_64",
                    "available_rpm": "kernel-4.19.90-2304.1.0.0196.oe1.x86_64",
                    "support_way": "coldpatch",
                }]

        """
        # Example of command execution result:
        # Last metadata expiration check: 0:26:36 ago on Mon 07 Aug 2023 10:26:32 AM CST.
        # CVE-2021-43976  Important/Sec. kernel-4.19.90-2201.1.0.0132.oe1.x86_64
        # CVE-2021-0941   Important/Sec. kernel-4.19.90-2201.1.0.0132.oe1.x86_64
        # CVE-2021-45469  Important/Sec. kernel-4.19.90-2201.1.0.0132.oe1.x86_64
        # CVE-2021-44733  Important/Sec. kernel-4.19.90-2201.1.0.0132.oe1.x86_64
        unfixed_cves = []
        commands = ["dnf updateinfo list cves"]
        if self.kernel_filter:
            commands.append("grep kernel")
        code, stdout, stderr = execute_shell_command(commands)

        if code != CommandExitCode.SUCCEED:
            LOGGER.error("query unfixed cve info failed by dnf!")
            LOGGER.error(stderr)
            return unfixed_cves

        # Example of regex matching result:
        # [
        # ("CVE-2021-43976", "Important/Sec.", "kernel-4.19.90-2201.1.0.0132.oe1.x86_64"),
        # ("CVE-2021-0941", "Important/Sec.", "kernel-4.19.90-2201.1.0.0132.oe1.x86_64")
        # ]
        pattern = r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(\S+)"
        if self.kernel_filter:
            pattern = r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(kernel-\d\S+)"

        for cve_id, _, coldpatch in re.findall(pattern, stdout):
            rpm_name = coldpatch.rsplit("-", 2)[0]
            unfixed_cves.append(self._build_unfixed_cve_info(cve_id, rpm_name, "coldpatch", coldpatch))
        return unfixed_cves

    def _query_unfixed_cves_by_dnf_plugin(self) -> list:
        """
        parse unfixed vulnerability info by dnf hotpatch plugin (hotpatch and coldpatch)

        Side effect: adds "<cve_id>-<rpm name>" to self.available_hotpatch_key_set for every CVE
        that has an available hotpatch.

        Return:
            list: cve info, empty when the command fails e.g
                [{
                    "cve_id": "CVE-2023-1513",
                    "installed_rpm": "kernel-4.19.90-2304.1.0.0131.oe1.x86_64",
                    "available_rpm": "kernel-4.19.90-2304.1.0.0196.oe1.x86_64",
                    "support_way": "coldpatch",
                }]
        """
        # Example of command execution result:
        # Last metadata expiration check: 0:31:50 ago on Mon 07 Aug 2023 10:26:32 AM CST.
        # CVE-2023-1981   Moderate/Sec.  avahi-libs-0.8-9.oe1.x86_64                     -
        # CVE-2021-42574  Important/Sec. binutils-2.34-19.oe1.x86_64                     -
        # CVE-2023-1513   Important/Sec. kernel-4.19.90-2304.1.0.0196.oe1.x86_64         patch-kernel-4.19.90-2112...
        cve_info_list = []
        commands = ["dnf hot-updateinfo list cves"]
        if self.kernel_filter:
            commands.append("grep kernel")
        code, stdout, stderr = execute_shell_command(commands)
        if code != CommandExitCode.SUCCEED:
            LOGGER.error("query unfixed cve info failed by dnf!")
            LOGGER.error(stderr)
            return cve_info_list

        # Example of regex matching result:
        # [
        # ("CVE-2023-1513", "Important/Sec.", "kernel-4.19.90-2304.1.0.0196.oe1.x86_64", "patch-kernel-4.19.90-2112.."),
        # ("CVE-2021-xxxx", "Important/Sec.", "-", "patch-redis-6.2.5-1-SGL_CVE_2023_1111_CVE_2023_1112-1-1.x86_64")
        # ]
        pattern = r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(\S+|-)\s+(\S+|-)"
        if self.kernel_filter:
            pattern = r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(kernel-\d\S+|-)\s+(patch-kernel-\d\S+|-)"

        # Deduplicate coldpatch / "no patch" records per CVE and package; hotpatch records are never deduplicated.
        coldpatch_key_set, unavailable_hotpatch_key_set = set(), set()
        for cve_id, _, coldpatch, hotpatch in re.findall(pattern, stdout):
            rpm_name = coldpatch.rsplit("-", 2)[0]
            if hotpatch != "-" and coldpatch == "-":
                # Only a hotpatch is listed, so derive the rpm name from it, e.g.
                # patch-redis-6.2.5-1-SGL_CVE_2023_1111_CVE_2023_1112-1-1.x86_64 -> redis
                try:
                    rpm_name = hotpatch.rsplit("-", 5)[0].split("-", 1)[1]
                except IndexError as error:
                    LOGGER.warning(error)
                    rpm_name = ""
            key = f"{cve_id}-{rpm_name}"

            if hotpatch != "-":
                cve_info_list.append(self._build_unfixed_cve_info(cve_id, rpm_name, "hotpatch", hotpatch))
                self.available_hotpatch_key_set.add(key)

            if coldpatch != "-" and key not in coldpatch_key_set:
                cve_info_list.append(self._build_unfixed_cve_info(cve_id, rpm_name, "coldpatch", coldpatch))
                coldpatch_key_set.add(key)

            if coldpatch == "-" and hotpatch == "-" and key not in unavailable_hotpatch_key_set:
                cve_info_list.append(self._build_unfixed_cve_info(cve_id, rpm_name, None, coldpatch))
                unavailable_hotpatch_key_set.add(key)

        return cve_info_list

    def _query_fixed_cves_by_dnf(self) -> list:
        """
        parse the fixed vulnerability info by dnf (coldpatch only)

        A CVE is reported as fixed when the installed package is not older than the listed fix
        (for the "kernel" package, the running kernel is also accepted). Packages that are not
        installed are skipped.

        Return:
            list: cve info, empty when the command fails e.g
                [
                    {"cve_id": "CVE-XXXX-XXXX","installed_rpm": "kernel-version-release.arch", "fix_way":"coldpatch"},
                ]

        """
        # Example of command execution result:
        # Last metadata expiration check: 0:26:36 ago on Mon 07 Aug 2023 10:26:32 AM CST.
        # CVE-2021-43976  Important/Sec. kernel-4.19.90-2201.1.0.0132.oe1.x86_64
        # CVE-2021-0941   Important/Sec. kernel-4.19.90-2201.1.0.0132.oe1.x86_64
        # CVE-2021-45469  Important/Sec. kernel-4.19.90-2201.1.0.0132.oe1.x86_64
        # CVE-2021-44733  Important/Sec. kernel-4.19.90-2201.1.0.0132.oe1.x86_64
        fixed_cves = []
        current_kernel_version = Collect.get_current_kernel_version()
        if not current_kernel_version:
            return fixed_cves
        current_kernel_rpm_name = f"kernel-{current_kernel_version}"

        commands = ["dnf updateinfo list cves --installed"]
        if self.kernel_filter:
            commands.append("grep kernel")
        code, stdout, stderr = execute_shell_command(commands)

        if code != CommandExitCode.SUCCEED:
            LOGGER.error("query fixed cve info failed!")
            LOGGER.error(stderr)
            return fixed_cves

        # Example of regex matching result:
        # [
        # ("CVE-2021-43976","Important/Sec.", "kernel-4.19.90-2201.1.0.0132.oe1.x86_64"),
        # ("CVE-2021-0941","Important/Sec.", "kernel-4.19.90-2201.1.0.0132.oe1.x86_64")
        # ]
        # Generic pattern by default, kernel-only pattern when filtering (same as the other queries).
        pattern = r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(\S+)"
        if self.kernel_filter:
            pattern = r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(kernel-\d\S+)"

        for cve_id, _, coldpatch in re.findall(pattern, stdout):
            rpm_name = coldpatch.rsplit("-", 2)[0]
            install_rpm = self._get_installed_rpm(rpm_name)
            if not install_rpm:
                # Not installed (or unknown): cannot decide whether the CVE is fixed.
                continue
            # Only the kernel package may be compared with the running kernel's rpm name.
            fixed_by_running_kernel = rpm_name == "kernel" and coldpatch <= current_kernel_rpm_name
            # rsplit(".", 2)[0] drops the ".<dist>.<arch>" suffix before comparing.
            if fixed_by_running_kernel or coldpatch.rsplit(".", 2)[0] <= install_rpm.rsplit(".", 2)[0]:
                fixed_cves.append({"cve_id": cve_id, "installed_rpm": install_rpm, "fix_way": "coldpatch"})
        return fixed_cves

    def _query_fixed_cves_by_dnf_plugin(self) -> list:
        """
        parse the fixed vulnerability info by dnf plugin

        Relies on self.available_hotpatch_key_set having been filled by
        _query_unfixed_cves_by_dnf_plugin (used by _query_applied_hotpatch_status).

        Return:
            list: coldpatch records followed by hotpatch records, empty when the command fails. e.g
                [{"cve_id": "CVE-XXXX-XXXX", "fix_way": "hotpatch", "hp_status": "ACCEPTED", "installed_rpm":"xxxx"}]

        """
        # Example of command execution result:
        # Last metadata expiration check: 0:31:50 ago on Mon 07 Aug 2023 10:26:32 AM CST.
        # CVE-2023-1981   Moderate/Sec.  avahi-libs-0.8-9.oe1.x86_64                     -
        # CVE-2021-42574  Important/Sec. binutils-2.34-19.oe1.x86_64                     -
        # CVE-2023-1513   Important/Sec. kernel-4.19.90-2304.1.0.0196.oe1.x86_64         patch-kernel-4.19.90-2112...
        current_kernel_version = Collect.get_current_kernel_version()
        if not current_kernel_version:
            return []
        current_kernel_rpm_name = f"kernel-{current_kernel_version}"

        commands = ["dnf hot-updateinfo list cves --installed"]
        if self.kernel_filter:
            commands.append("grep kernel")
        code, stdout, stderr = execute_shell_command(commands)
        if code != CommandExitCode.SUCCEED:
            LOGGER.error("query fixed cve info failed by dnf!")
            LOGGER.error(stderr)
            return []

        # Example of regex matching result:
        # [
        # ("CVE-2023-1513", "Important/Sec.", "kernel-4.19.90-2304.1.0.0196.oe1.x86_64", "patch-kernel-4.19.90-2112.."),
        # ("CVE-2021-xxxx", "Important/Sec.", "-", "patch-redis-6.2.5-1-SGL_CVE_2023_1111_CVE_2023_1112-1-1.x86_64")
        # ]
        hotpatch_status = self._query_applied_hotpatch_status()
        pattern = r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(\S+|-)\s+(patch-\S+|-)"
        if self.kernel_filter:
            pattern = r"(CVE-\d{4}-\d+)\s+([\w+/.]+)\s+(kernel-\d\S+|-)\s+(patch-kernel-\d\S+|-)"

        cve_info_fixed_by_coldpatch, cve_info_fixed_by_hotpatch = [], []
        # "patch-<...>-ACC" prefix -> greatest (lexicographic) ACC hotpatch seen for that prefix.
        latest_acc_hotpatch = {}
        for cve_id, _, coldpatch, hotpatch in re.findall(pattern, stdout):
            if hotpatch == "-":
                installed_rpm = self._get_installed_rpm(coldpatch.rsplit("-", 2)[0])
                if not installed_rpm:
                    # Unknown package (or "-" coldpatch column): nothing to compare against.
                    continue
                # Skip fixes newer than what is running/installed; rsplit(".", 2)[0] drops ".<dist>.<arch>".
                if (self.kernel_filter and coldpatch > current_kernel_rpm_name) or (
                    coldpatch.rsplit(".", 2)[0] > installed_rpm.rsplit(".", 2)[0]
                ):
                    continue
                cve_info_fixed_by_coldpatch.append(
                    {
                        "cve_id": cve_id,
                        "installed_rpm": installed_rpm,
                        "fix_way": "coldpatch",
                    }
                )
                continue

            cve_info_fixed_by_hotpatch.append({"cve_id": cve_id, "fix_way": "hotpatch", "installed_rpm": hotpatch})
            acc_key = hotpatch.rsplit("-", 2)[0]
            if acc_key.endswith("ACC"):
                # NOTE(review): string max assumes ACC hotpatch versions sort correctly as text — confirm.
                latest_acc_hotpatch[acc_key] = max(hotpatch, latest_acc_hotpatch.get(acc_key, hotpatch))

        for cve_info in cve_info_fixed_by_hotpatch:
            acc_key = cve_info["installed_rpm"].rsplit("-", 2)[0]
            if acc_key in latest_acc_hotpatch:
                cve_info["installed_rpm"] = latest_acc_hotpatch[acc_key]
            # Status keys have no ".<arch>" suffix, e.g. "patch-redis-6.2.5-1-ACC-1-3".
            cve_info["hp_status"] = hotpatch_status.get(cve_info["installed_rpm"].rsplit(".", 1)[0], "")

        return cve_info_fixed_by_coldpatch + cve_info_fixed_by_hotpatch

    def _query_applied_hotpatch_status(self) -> Dict[str, str]:
        """
        query applied hotpatch with its status

        Only ACTIVED/ACCEPTED hotpatches are returned, at most one per CVE and package, and only for
        CVE/package pairs not present in self.available_hotpatch_key_set.

        Return:
            dict: key is hotpatch name, value is its status. e.g {"patch-redis-6.2.5-1-ACC-1-3": "ACTIVED"}

        """
        # Example of command execution result:
        # Last metadata expiration check: 0:28:36 ago on Mon 07 Aug 2023 10:26:32 AM CST.
        # CVE-id        base-pkg/hotpatch                                                 status
        # CVE-2023-1111 redis-6.2.5-1/ACC-1-1/redis-benchmark                             ACTIVED
        # CVE-2023-1112 redis-6.2.5-1/ACC-1-1/redis-benchmark                             ACTIVED
        # CVE-2023-1111 redis-6.2.5-1/ACC-1-1/redis-cli                                   ACTIVED
        # CVE-2023-1112 redis-6.2.5-1/ACC-1-1/redis-cli                                   ACTIVED
        # CVE-2023-1111 redis-6.2.5-1/ACC-1-1/redis-server                                NOT-APPLIED
        # CVE-2023-1112 redis-6.2.5-1/ACC-1-1/redis-server                                NOT-APPLIED
        # CVE-2023-2221 redis-6.2.5-1/ACC-1-2/redis-cli                                   NOT-APPLIED
        # CVE-2023-2222 redis-6.2.5-1/ACC-1-2/redis-cli                                   NOT-APPLIED
        # CVE-2023-1111 redis-6.2.5-1/SGL_CVE_2023_1111_CVE_2023_1112-1-1/redis-benchmark NOT-APPLIED
        # CVE-2023-1112 redis-6.2.5-1/SGL_CVE_2023_1111_CVE_2023_1112-1-1/redis-benchmark NOT-APPLIED
        # CVE-2023-1111 redis-6.2.5-1/SGL_CVE_2023_1111_CVE_2023_1112-1-1/redis-cli       NOT-APPLIED
        # CVE-2023-1112 redis-6.2.5-1/SGL_CVE_2023_1111_CVE_2023_1112-1-1/redis-cli       NOT-APPLIED
        # CVE-2023-1111 redis-6.2.5-1/SGL_CVE_2023_1111_CVE_2023_1112-1-1/redis-server    NOT-APPLIED
        # CVE-2023-1112 redis-6.2.5-1/SGL_CVE_2023_1111_CVE_2023_1112-1-1/redis-server    NOT-APPLIED
        result = {}
        code, stdout, stderr = execute_shell_command(["dnf hotpatch --list cves"])
        if code != CommandExitCode.SUCCEED:
            LOGGER.error("query applied hotpatch info failed!")
            LOGGER.error(stderr)
            return result

        # Example of regex matching result:
        # [
        # ("CVE-2023-1112", "redis-6.2.5-1/SGL_CVE_2023_1111_CVE_2023_1112-1-1/redis-server", "NOT-APPLIED"),
        # ("CVE-2023-1111", "redis-6.2.5-1/ACC-1-1/redis-benchmark", "ACTIVED")
        # ]
        # Note: [A-W]+ stops at '-', so "NOT-APPLIED" is captured as "NOT"; only ACTIVED/ACCEPTED matter below.
        pattern = r"(CVE-\d{4}-\d+)\s+([\w\-/.]+)\s+([A-W]+)"
        if self.kernel_filter:
            pattern = r"(CVE-\d{4}-\d+)\s+(kernel-\d[\w\-/.]+)\s+([A-W]+)"

        record_key_set = set()
        for cve_id, patch_name, status in re.findall(pattern, stdout):
            # Must match the "<cve_id>-<rpm name>" keys written by _query_unfixed_cves_by_dnf_plugin,
            # so hyphenated rpm names (e.g. avahi-libs) have to be kept whole.
            record_key = f"{cve_id}-{self._hotpatch_base_package_name(patch_name)}"
            # A CVE/package that still has an available hotpatch is treated as not fixed by hotpatch, e.g.
            # CVE-2023-1111 redis-6.2.5-1/ACC-1-1/redis-benchmark                             ACTIVED
            # CVE-2023-1111 redis-6.2.5-1/ACC-1-1/redis-cli                                   ACTIVED
            # CVE-2023-1111 redis-6.2.5-1/ACC-1-1/redis-server                                NOT-APPLIED
            if (
                (record_key not in self.available_hotpatch_key_set)
                and (status in ("ACTIVED", "ACCEPTED"))
                and record_key not in record_key_set
            ):
                result[f"patch-{patch_name.rsplit('/',1)[0].replace('/','-')}"] = status
                record_key_set.add(record_key)
        return result
